{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import os\n",
    "from dotenv import load_dotenv\n",
    "# Load environment variables from openai.env file\n",
    "load_dotenv(\"openai.env\")\n",
    "\n",
    "# Read the OPENAI_API_KEY from the environment\n",
    "api_key = os.getenv(\"OPENAI_API_KEY\")\n",
    "api_base = os.getenv(\"OPENAI_API_BASE\")\n",
    "os.environ[\"OPENAI_API_KEY\"] = api_key\n",
    "os.environ[\"OPENAI_API_BASE\"] = api_base"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ChatDOc:和文件聊天"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/tomiezhang/.pyenv/versions/3.10.12/envs/langchains/lib/python3.11/site-packages/langchain_core/_api/deprecation.py:117: LangChainDeprecationWarning: The class `langchain_community.embeddings.openai.OpenAIEmbeddings` was deprecated in langchain-community 0.0.9 and will be removed in 0.2.0. An updated version of the class exists in the langchain-openai package and should be used instead. To use it run `pip install -U langchain-openai` and import as `from langchain_openai import OpenAIEmbeddings`.\n",
      "  warn_deprecated(\n",
      "/Users/tomiezhang/.pyenv/versions/3.10.12/envs/langchains/lib/python3.11/site-packages/langchain_core/_api/deprecation.py:117: LangChainDeprecationWarning: The class `langchain_community.chat_models.openai.ChatOpenAI` was deprecated in langchain-community 0.0.10 and will be removed in 0.2.0. An updated version of the class exists in the langchain-openai package and should be used instead. To use it run `pip install -U langchain-openai` and import as `from langchain_openai import ChatOpenAI`.\n",
      "  warn_deprecated(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "AIMessage(content='公司的注册地址是江苏省南京市雨花台区软件大道101号。', response_metadata={'finish_reason': 'stop', 'logprobs': None})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#导入必须的包\n",
    "from langchain.document_loaders import UnstructuredExcelLoader,Docx2txtLoader,PyPDFLoader\n",
    "from langchain.text_splitter import  CharacterTextSplitter\n",
    "from langchain.embeddings import  OpenAIEmbeddings\n",
    "from langchain.vectorstores import  Chroma\n",
    "#导入聊天所需的模块\n",
    "from langchain.chat_models import ChatOpenAI\n",
    "from langchain.prompts import ChatPromptTemplate\n",
    "\n",
    "\n",
    "#定义chatdoc\n",
    "class ChatDoc():\n",
    "    def __init__(self):\n",
    "        self.doc = None\n",
    "        self.splitText = [] #分割后的文本\n",
    "        self.template = [\n",
    "            (\"system\",\"你是一个处理文档的秘书,你从不说自己是一个大模型或者AI助手,你会根据下面提供的上下文内容来继续回答问题.\\n 上下文内容\\n {context} \\n\"),\n",
    "            (\"human\",\"你好！\"),\n",
    "            (\"ai\",\"你好\"),\n",
    "            (\"human\",\"{question}\"),\n",
    "        ]\n",
    "        self.prompt = ChatPromptTemplate.from_messages(self.template)\n",
    "\n",
    "    def getFile(self):\n",
    "        doc = self.doc\n",
    "        loaders = {\n",
    "            \"docx\":Docx2txtLoader,\n",
    "            \"pdf\":PyPDFLoader,\n",
    "            \"xlsx\":UnstructuredExcelLoader,\n",
    "        }\n",
    "        file_extension = doc.split(\".\")[-1]\n",
    "        loader_class = loaders.get(file_extension)\n",
    "        if loader_class:\n",
    "            try:\n",
    "                loader = loader_class(doc)\n",
    "                text = loader.load()\n",
    "                return text\n",
    "            except Exception as e: \n",
    "                print(f\"Error loading {file_extension} files:{e}\") \n",
    "        else:\n",
    "             print(f\"Unsupported file extension: {file_extension}\")\n",
    "             return  None \n",
    "\n",
    "    #处理文档的函数\n",
    "    def splitSentences(self):\n",
    "        full_text = self.getFile() #获取文档内容\n",
    "        if full_text != None:\n",
    "            #对文档进行分割\n",
    "            text_split = CharacterTextSplitter(\n",
    "                chunk_size=150,\n",
    "                chunk_overlap=20,\n",
    "            )\n",
    "            texts = text_split.split_documents(full_text)\n",
    "            self.splitText = texts\n",
    "    \n",
    "    #向量化与向量存储\n",
    "    def embeddingAndVectorDB(self):\n",
    "        embeddings = OpenAIEmbeddings()\n",
    "        db =Chroma.from_documents(\n",
    "            documents = self.splitText,\n",
    "            embedding = embeddings,\n",
    "        )\n",
    "        return db\n",
    "    \n",
    "    #提问并找到相关的文本块\n",
    "    def askAndFindFiles(self,question):\n",
    "        db = self.embeddingAndVectorDB()\n",
    "        #retriever = db.as_retriever(search_type=\"mmr\")\n",
    "        retriever = db.as_retriever(search_type=\"similarity_score_threshold\",search_kwargs={\"score_threshold\":.5,\"k\":1})\n",
    "        return retriever.get_relevant_documents(query=question)\n",
    "    \n",
    "    #用自然语言和文档聊天\n",
    "    def chatWithDoc(self,question):\n",
    "        _content = \"\"\n",
    "        context = self.askAndFindFiles(question)\n",
    "        for i in context:\n",
    "            _content += i.page_content\n",
    "        \n",
    "        messages = self.prompt.format_messages(context=_content,question=question)\n",
    "        chat = ChatOpenAI(\n",
    "            model=\"gpt-3.5-turbo\",\n",
    "            temperature=0,\n",
    "        )\n",
    "        return chat.invoke(messages)\n",
    "\n",
    "chat_doc = ChatDoc()\n",
    "chat_doc.doc = \"example/fake.docx\"\n",
    "chat_doc.splitSentences()\n",
    "chat_doc.chatWithDoc(\"公司注册地址是哪里？\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "langchains",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
